Q-learning (fonction de valeur des états-actions)

Présentation

Le Q-learning est une technique d'apprentissage par renforcement. Il utilise la fonction Q (fonction de valeur des états-actions) qui repose sur un tableau que l'on nomme la Q-table. Les index de cette Q-table représentent les différents état du système. Chaque état dispose d'un tableau d'actions où chacune d'elles possède une valeur correspondant à une estimation qualitative de cette l'action.

Démonstration

Entrainement de l'IA
0 0 0 1 2 3 4 5 6 7 0

États

  • Collisions aux alentours de la tête du serpent: (3 directions possibles = 2**3 = 8 états possibles).
    | a = a1 * 4 + a2 * 2 + a3 = ?
  • Direction de la pomme: (8 directions = 8 états possibles).
    | b = ?
  • État: (8 * 8 = 64 états possibles).
    | a * 8 + b = ?

Exploitation

  • Action: (max( Q-Table[ état ])).
    | Q-Table[ ? ][ ?, ?, ? ] = ?

L'Exploitation

La décision optimale est prise à partir de la Q-table. L'agent va y sélectionner l'action possédant la plus grande valeur parmi toutes les actions disponibles pour l'état actuel

L'Apprentissage

Les récompenses étant distribuées après chaque action de l'agent, c'est à ce moment là que la valeur de l'action est mise à jour dans la Q-table à l'aide de la fonction Q.

q_table[état][action] = (1 - alpha) * q_table[état][action] + alpha * (récompense + gamma * q_table[état_suivant][argmax(q_table[état_suivant])])

Le facteur d'actualisation (gamma) modifie l'importance de l'influence de la meilleure action à l'état suivant sur l'action courrante. Il oscille entre 0 et 1. Plus la valeur se rapprochent de 1, plus l'influence sera forte.

Code

				
					// LE MODULO (adapté pour jongler avec les orientations relatives).
					var modulo = function(a, b)
					{
						return ((a % b) + b) % b;
					};

					// LES ÉTATS.
					var State = function(){}; 

					State.prototype.get_state = function(agent, apple, check_wallCollision, check_tailCollision)
					{
						// Coordonnées de la tête du serpent.
						var x = agent.matrix[0][0];
						var y = agent.matrix[0][1];

						var st1 = this.get_surrounding(x, y, agent, check_wallCollision, check_tailCollision);
						var st2 = this.get_appleDirection(x, y, agent.orientation, apple);

						// Combiner les états (8*8 état possibles).
						return parseInt(st1 * 8 + st2);
					};

					State.prototype.get_surrounding = function(x, y, agent, check_wallCollision, check_tailCollision)
					{
						// Présence de collisions aux alentours de la tête du serpent pour les 3 actions (2**3).
						var surrounding = [0, 0, 0];

						for (var i = surrounding.length - 1; i >= 0; i--)
						{
							var new_ori = modulo((agent.orientation + i - 1), 4);
							var new_pos = [x + agent.moves[new_ori][0], y + agent.moves[new_ori][1]];
							surrounding[i] = (check_wallCollision(new_pos) || check_tailCollision(new_pos));
						}

						return surrounding[0] * 4 + surrounding[1] * 2 + surrounding[2];
					};

					State.prototype.get_appleDirection = function(x, y, orientation, apple)
					{
						// Direction de la pomme en fonction de la position et de l'orientatation du serpent (8**1).
						var state;
						var ax = apple[0];
						var ay = apple[1];

						// Nord.
						if (ax == x && ay < y)
						{
							state = modulo((0 + 2 * orientation), 8);
						}
						// Nord-Est.
						else if (ax > x && ay < y)
						{
							state = modulo((7 + 2 * orientation), 8);
						}
						// Est.
						else if (ax > x && ay == y)
						{
							state = modulo((6 + 2 * orientation), 8);
						}
						// Sud-Est.
						else if (ax > x && ay > y)
						{
							state = modulo((5 + 2 * orientation), 8);
						}
						// Sud.
						else if (ax == x && ay > y)
						{
							state = modulo((4 + 2 * orientation), 8);
						}
						// Sud-Ouest.
						else if (ax < x && ay > y)
						{
							state = modulo((3 + 2 * orientation), 8);
						}
						// Ouest.
						else if (ax < x && ay == y)
						{
							state = modulo((2 + 2 * orientation), 8);
						}
						// Nord-Ouest.
						else if (ax < x && ay < y)
						{
							state = modulo((1 + 2 * orientation), 8);
						}

						return state;
					};

					// L'ENVIRONNEMENT.
					var Env = function(agent, cols_length, rows_length)
					{
						this.agent = agent;
						this.cols_length = cols_length;
						this.rows_length = rows_length;
						this.apple;
						this.time;
					};

					Env.prototype.restart = function()
					{
						this.time = 0;
						this.agent.reset(this.cols_length, this.rows_length);
						this.reset_apple();
					};

					Env.prototype.step = function(action)
					{
						this.time += 1;
						this.agent.update_orientation(action);
						return this.update(this.agent.get_newPosition());
					};

					Env.prototype.update = function(new_pos)
					{
						if (!this.check_wallCollision(new_pos) && !this.check_tailCollision(new_pos) && this.time < (this.rows_length * this.cols_length))
						{
							// Mettre la position de la tête à jour.
							this.agent.matrix.unshift(new_pos);

							// Pomme mangée.
							if (this.apple[0] == this.agent.matrix[0][0] && this.apple[1] == this.agent.matrix[0][1])
							{
								// Si la taille du serpent est égale au nombre de cases dans l'aire de jeu la partie est gagnée.
								if (this.agent.matrix.length == this.rows_length * this.cols_length)
								{
									// Partie gagnée - grosse récompense positive.
									return {reward: this.rows_length * this.cols_length, done: true};
								}

								this.time = 0;
								// Trouver une nouvelle position aléatoire pour la pomme.
								this.reset_apple();
								// Pomme mangée - récompense positive.
								return {reward: 1, done: false};
							}

							// Effacer l'ancienne position du bout de la queue.
							this.agent.matrix.pop(this.agent, this.cols_length, this.rows_length);
							// Mouvement du serpend - pas de récompense.
							return {reward: 0, done: false};
						}
						// Partie perdue - récompense négative.
						return {reward: -1, done: true};
					};

					Env.prototype.reset_apple = function()
					{
						var grid = [];
						var agent = JSON.stringify(this.agent.matrix);
						for (var c = 0; c < this.cols_length; c++)
						{
							for (var r = 0; r < this.rows_length; r++)
							{
								// Inclure uniquement les cases libres.
								if (agent.indexOf('['+c+','+r+']') === -1)
								{
									grid.push([c, r]);
								}
							}
						}
						this.apple = grid[Math.floor(Math.random() * grid.length)];
					};

					Env.prototype.check_wallCollision = function(coordinates)
					{
						// Vérifier la collision entre des coordonnées et les murs exterieurs du plateau de jeu.
						if ((coordinates[0] >= 0 && coordinates[0] < this.cols_length) && (coordinates[1] >= 0 && coordinates[1] < this.rows_length))
						{
							return false;
						}
						return true;
					};

					Env.prototype.check_tailCollision = function(coordinates)
					{
						//Vérifier la collision entre des coordonnées et les différentes parties de la queue du serpent.
						for (var i = this.agent.matrix.length - 1; i > 0; i--)
						{
							var pos = this.agent.matrix[i];
							if (coordinates[0] == pos[0] && coordinates[1] == pos[1])
							{
								return true;
							}
						}
						return false;
					};

					// L'AGENT.
					var Agent = function(states_length, actions_length, defaultValue, epsilon, alpha, gamma)
					{
						// Le tableau de valeur des états-actions.
						this.q_table = this.init_qTable(states_length, actions_length, defaultValue);
						// Taux d'exploration.
						this.epsilon = epsilon;
						// Taux d'apprentissage.
						this.alpha = alpha;
						// Facteur d'actualisation.
						this.gamma = gamma;
						// Déplacement: haut, droite, bas, gauche.
						this.moves = [[0, -1], [1, 0], [0, 1], [-1, 0]];
						// 0: haut, 1: droite, 2: bas, 3: gauche.
						this.orientation;
						// Les coordonnées des différents éléments du serpent.
						this.matrix;
					};

					Agent.prototype.init_qTable = function(states_length, actions_length, defaultValue)
					{
						var q_table = [];
						for (var s = 0; s < states_length; s++)
						{
							q_table[s] = [];
							for (var a = 0; a < actions_length; a++)
							{
								q_table[s][a] = defaultValue;
							}
						}
						return q_table;
					};

					Agent.prototype.reset = function(cols_length, rows_length)
					{
						this.orientation = 0;
						// Position aléatoire.
						this.matrix = [[Math.floor(Math.random() * cols_length), Math.floor(Math.random() * rows_length)]];
					};

					Agent.prototype.update_orientation = function(action)
					{
						this.orientation = modulo(this.orientation + action - 1, 4);
					};

					Agent.prototype.get_newPosition = function()
					{
						return [this.matrix[0][0] + this.moves[this.orientation][0], this.matrix[0][1] + this.moves[this.orientation][1]];
					};

					Agent.prototype.choose_action = function(state)
					{
						// Exploration:
						if (Math.random() < this.epsilon)
						{
							return Math.floor(Math.random() * 3);
						}
						// Exploitation:
						else
						{
							// Choisir l'action possédant la plus grande valeur.
							return this.argmax(this.q_table[state]);
						}
					};

					Agent.prototype.argmax = function(array)
					{
						var max = -Infinity;
						var index;
						for (var i = array.length - 1; i >= 0; i--)
						{
							if (array[i] > max)
							{
								max = array[i];
								index = i;
							}
						}
						return index;
					};

					Agent.prototype.update_qTable = function(state, action, next_state, reward)
					{
						// La fonction Q ou fonction de valeur des états-actions.
						this.q_table[state][action] = (1 - this.alpha) * this.q_table[state][action] + this.alpha * (reward + this.gamma * this.q_table[next_state][this.argmax(this.q_table[next_state])]);
					};

					// LE Q-LEARNING.
					var Qlearning = function()
					{
						this.cols_length = 10;
						this.rows_length = 10;
						this.state = new State();
					};

					Qlearning.prototype.train = function(epoch, batch_size)
					{
						var agent = new Agent(64, 3, 0, 1, 0.1, 0.1);
						var env = new Env(agent, this.cols_length, this.rows_length);

						// Le nombre d'entrainement avec une diminution progressive d'epsilon et d'alpha.
						for (var e = 0; e < epoch; e++)
						{
							var scores = [];
							// Le nombre d'entrainement avec les valeurs d'epsilon et d'alpha actuelles.
					        for (var b = 0; b < batch_size; b++)
					        {
					            env.restart();
					            var state = this.state.get_state(env.agent, env.apple, env.check_wallCollision.bind(env), env.check_tailCollision.bind(env));
					            var done = false;
					            while (done === false)
					            {
					                var action = env.agent.choose_action(state);
					                var gameInfos = env.step(action);
					                var next_state = this.state.get_state(env.agent, env.apple, env.check_wallCollision.bind(env), env.check_tailCollision.bind(env));
					                var reward = gameInfos.reward;
					               	env.agent.update_qTable(state, action, next_state, reward);

									done = gameInfos.done;
									state = next_state;
					            }
					            scores.push(env.agent.matrix.length);
					        }
					        // Afficher les statistiques.
					       	this.display_results(scores, e, env.agent);
					        // Diminuer le taux d'exploration.
					       	env.agent.epsilon *= 0.95;
					    	// Diminuer le taux d'apprentissage.
					        env.agent.alpha *= 0.99;

						};
					};

					Qlearning.prototype.display_results = function(array, epoch, agent)
					{
						var output = 0;
						var length = array.length;
						for (var i = 0; i < length; i++)
						{
							output += array[i];
						}
						var average = Math.round(output / length);
						console.log(epoch + " | moyenne: " + average + ", max: " + array[agent.argmax(array)] + ", epsilon: " + Math.round(agent.epsilon * 1000) / 1000 + ", alpha: " + Math.round(agent.alpha * 1000) / 1000 + ", gamma: " + agent.gamma);
					};

					var demo = new Qlearning();
					demo.train(100, 100);